import json
import numpy as np
import optax
import pandas as pd

from argparse import Namespace

from collections import OrderedDict
from src.trainers.base_trainer import BaseTrainer
from src.tasks.nextcharprediction import *
from src.predictors import *
from src.tasks.envs.tracepatterning import compute_return_error
from src.dataloaders import NextCharDataLoader,AnimalBehaviourDataLoader,MiniGridDataLoader



def initialize_model(args):
    """Instantiate the predictor network selected by ``args.model``.

    Args:
        args: Namespace-like object. ``args.model`` picks the architecture;
            which other attributes are read depends on the branch taken
            (``d_model``, ``n_heads``, ``d_ffc``, ``n_layers``,
            ``output_size``, ``truncation``, ``pos_emb_type``,
            ``use_layer_emb``, ``no_memory``, ``update_rule``).

    Returns:
        The constructed predictor module.

    Raises:
        ValueError: If ``args.model`` is not one of the supported names.
    """
    name = args.model
    if name == 'lt_tbptt':
        return LinearTransformerPredictor(
            d_model=args.d_model, n_heads=args.n_heads, d_ffc=args.d_ffc,
            n_layers=args.n_layers, output_size=args.output_size,
            truncation=args.truncation, kernel_phi=eluplus1)
    if name == 'universal_linear_transformer':
        return UniversalTransformerPredictor(
            d_model=args.d_model, n_heads=args.n_heads, d_ffc=args.d_ffc,
            n_layers=args.n_layers, output_size=args.output_size,
            truncation=args.truncation, kernel_phi=eluplus1)
    # NOTE(review): the three branches below pass d_model as d_ffc and as
    # kernel_dim, so args.d_ffc is ignored for them -- confirm this is intended.
    # NOTE(review): the key is spelled 'oult_tbppt' (not 'tbptt'); kept as-is
    # because existing configs may rely on this exact string.
    if name == 'oult_tbppt':
        return OULTTBBPTPredictor(
            n_layers=args.n_layers, d_model=args.d_model, d_ffc=args.d_model,
            n_heads=args.n_heads, kernel_dim=args.d_model,
            truncation=args.truncation, output_size=args.output_size,
            pos_emb_type=args.pos_emb_type, use_layer_emb=args.use_layer_emb,
            kernel_phi=eluplus1, update_rule=args.update_rule)
    if name == 'rlt_tbptt':
        return RLTTBBPTPredictor(
            n_layers=args.n_layers, d_model=args.d_model, d_ffc=args.d_model,
            n_heads=args.n_heads, kernel_dim=args.d_model,
            truncation=args.truncation, output_size=args.output_size,
            pos_emb_type=args.pos_emb_type, use_layer_emb=args.use_layer_emb,
            no_memory=args.no_memory, kernel_phi=eluplus1,
            update_rule=args.update_rule)
    if name == 'rlt_tuoro':
        return RLTTUOROPredictor(
            n_layers=args.n_layers, d_model=args.d_model, d_ffc=args.d_model,
            n_heads=args.n_heads, kernel_dim=args.d_model,
            truncation=args.truncation, output_size=args.output_size,
            pos_emb_type=args.pos_emb_type, use_layer_emb=args.use_layer_emb,
            no_memory=args.no_memory, kernel_phi=eluplus1,
            update_rule=args.update_rule)
    if name == 'rnn_tbptt':
        return VanillaRNNPredictor(
            d_model=args.d_model, output_size=args.output_size,
            truncation=args.truncation)
    if name == 'lstm_tbptt':
        return LSTMTBBPTPredictor(
            d_model=args.d_model, output_size=args.output_size,
            truncation=args.truncation, n_layers=args.n_layers)

    # Previously an unmatched name fell through to `return model` and failed
    # with an UnboundLocalError; fail with an actionable message instead.
    supported = ('lt_tbptt', 'universal_linear_transformer', 'oult_tbppt',
                 'rlt_tbptt', 'rlt_tuoro', 'rnn_tbptt', 'lstm_tbptt')
    raise ValueError(
        f"Unknown model {name!r}; expected one of: {', '.join(supported)}")

class NextCharTrainer(BaseTrainer):
    """Online trainer for next-character prediction.

    Each call to :meth:`step` pulls one ``(inputs, targets)`` pair from a
    ``NextCharDataLoader``, takes one Adam step on the cross-entropy between
    ``targets`` and the model's softmax output, and records the loss in bits.
    """

    def __init__(self,**kwargs):
        """Build the data loader, model, optimizer and jitted update function.

        Keyword Args:
            trainer_config: dict of model/optimizer settings; must contain
                ``model`` and ``lr`` plus whatever ``initialize_model`` reads.
            seed: passed to the data loader's ``init``.
            env_config: passed to the data loader's ``init``.
            global_args: object exposing ``truncation`` and ``log_interval``.
            key: JAX PRNG key used for parameter init and the 'random' stream.
        """
        self.config=kwargs['trainer_config']
        self.data_gen=NextCharDataLoader()
        self.data_gen.init(kwargs['seed'],kwargs['env_config'])
        input_size=self.data_gen.input_size
        args=Namespace(**self.config)
        args.truncation=kwargs['global_args'].truncation
        # The model predicts a distribution over the same alphabet it reads.
        args.output_size=input_size
        self.model=initialize_model(args)
        params_key,self.random_key=jax.random.split(kwargs['key'])
        variables=self.model.init({'params':params_key,'random':self.random_key},jnp.zeros((input_size)))
        # Split trainable params from the remaining (mutable) collections.
        # Written against the generic mapping interface so it works whether
        # `init` returns a FrozenDict or a plain dict (dict.pop would return
        # only the value, which broke the previous 2-tuple unpacking).
        self.params=variables['params']
        self.state={name: col for name, col in variables.items() if name!='params'}
        def params_sum(params):
            # jax.tree_map is a removed top-level alias; count via tree_util.
            return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
        print("Number of params: ",params_sum(self.params))

        #Setup optimizer
        self.optimizer=optax.adam(args.lr)
        self.optimizer_state=self.optimizer.init(self.params)

        @jax.jit
        def update_step(inputs, target, opt_state, params, state,random_key):
            """One gradient step; returns (loss_in_bits, opt_state, params, state)."""
            def loss(params,inputs,target):
                pred, updated_state = self.model.apply({'params': params, **state},
                                            inputs, mutable=list(state.keys()),rngs={'random':random_key})
                # Cross-entropy in nats; this is the quantity differentiated.
                l = -jnp.dot(target,jax.nn.log_softmax(pred))
                # Same loss in bits for reporting. Derived from the stable
                # log_softmax value: log2(softmax(pred)) can hit -inf and
                # 0 * -inf in the dot product yields NaN.
                l2 = l/jnp.log(2.0)
                return l,(l2, updated_state)
            (l,(l2, updated_state)), grads = jax.value_and_grad(
                loss, has_aux=True)(params,inputs,target)
            updates, opt_state = self.optimizer.update(grads, opt_state)  # self.optimizer is set up above.
            params = optax.apply_updates(params, updates)
            return l2,opt_state, params, updated_state
        self.update_step=update_step
        self.step_count=0
        self.global_args=kwargs['global_args']
        self.losses=[]       # losses since the last log point
        self.result_data=[]  # one {'step', 'loss'} row per log interval

    def step(self, **kwargs):
        """Run one training step.

        Returns:
            ``(loss, metrics)`` where ``loss`` is this step's loss in bits and
            ``metrics`` is ``{'loss': mean_loss}`` on every
            ``global_args.log_interval``-th step, otherwise ``None``.
        """
        inputs,targets=self.data_gen.step()
        # Advance the PRNG stream so each step sees a fresh key.
        self.random_key=jax.random.split(self.random_key)[0]
        loss,self.optimizer_state,self.params,self.state=self.update_step(inputs,targets,self.optimizer_state,self.params,self.state,self.random_key)
        self.losses.append(loss)
        if (self.step_count+1)%self.global_args.log_interval==0:
            mean_loss=np.mean(self.losses)
            self.result_data.append({'step':self.step_count,'loss':mean_loss})
            self.losses=[]
            metrics={'loss':mean_loss}
        else:
            metrics=None
        self.step_count+=1
        return loss,metrics


    def get_summary_table(self):
        """Return the logged ``step``/``loss`` rows as a JSON string."""
        return pd.DataFrame(self.result_data).to_json(default_handler=str)
    


class AnimalBehaviourTrainer(BaseTrainer):
    """Online TD(lambda) value-prediction trainer.

    Each call to :meth:`step` pulls ``(O_t, O_tplus1, R_tplus1)`` from an
    ``AnimalBehaviourDataLoader``, forms the TD error
    ``R_tplus1 + gamma * V(O_tplus1) - V(O_t)``, updates an eligibility trace
    of the value gradient, and feeds ``-delta * trace`` to Adam as the
    gradient.
    """

    def __init__(self,**kwargs):
        """Build the model, data loader, optimizer, trace and jitted update.

        Keyword Args:
            trainer_config: dict of settings; must contain ``model``, ``lr``
                and ``lamb`` plus whatever ``initialize_model`` reads.
            global_args: object exposing ``truncation`` and ``log_interval``.
            seed: passed to the data loader's ``init``.
            env_config: passed to the data loader's ``init``.
            key: JAX PRNG key used for parameter init and the 'random' stream.
        """
        self.config=kwargs['trainer_config']
        args=Namespace(**self.config)
        self.global_args=kwargs['global_args']
        args.truncation=self.global_args.truncation
        # Single output: the predicted value.
        args.output_size=1
        self.model=initialize_model(args)
        self.data_gen=AnimalBehaviourDataLoader()
        self.data_gen.init(kwargs['seed'],kwargs['env_config'])
        params_key,self.random_key=jax.random.split(kwargs['key'])
        variables=self.model.init({'params':params_key,'random':self.random_key},self.data_gen.O_t)
        # Split trainable params from the remaining (mutable) collections.
        # Written against the generic mapping interface so it works whether
        # `init` returns a FrozenDict or a plain dict (dict.pop would return
        # only the value, which broke the previous 2-tuple unpacking).
        self.params=variables['params']
        self.state={name: col for name, col in variables.items() if name!='params'}
        def params_sum(params):
            # jax.tree_leaves / jax.tree_map are removed top-level aliases;
            # count via jax.tree_util instead.
            return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
        print("Number of params: ",params_sum(self.params))

        #Setup optimizer
        self.optimizer=optax.adam(args.lr)
        self.optimizer_state=self.optimizer.init(self.params)

        #Eligibility traces: one zero-initialised buffer per parameter leaf.
        self.egb_trace=jax.tree_util.tree_map(lambda x: jnp.zeros_like(x),self.params)

        self.gamma=self.data_gen.config['gamma']
        lamb=args.lamb

        @jax.jit
        def update_step(O_t, R_tplus1, O_tplus1, opt_state, params, state,egb_trace,random_key):
            """One TD(lambda) step.

            Returns ``(delta**2, V_t, opt_state, params, state, egb_trace)``.
            """
            def value_fn(params,state,O_t, R_tplus1, O_tplus1):
                V_t, updated_state = self.model.apply({'params': params, **state},
                                            O_t, mutable=list(state.keys()),rngs={'random':random_key})
                # Bootstrap value from the post-O_t state; the state produced
                # by this second call is discarded.
                V_tplus1, _ = self.model.apply({'params': params, **updated_state},
                                            O_tplus1, mutable=list(updated_state.keys()),rngs={'random':random_key})

                delta=R_tplus1+self.gamma*V_tplus1-V_t
                # Only V_t[0] is differentiated; delta is auxiliary output, so
                # no gradient flows through the bootstrap target.
                return V_t[0], (delta,updated_state)

            (V_t, (delta,updated_state)), delta_V_t = jax.value_and_grad(
                value_fn, has_aux=True)(params,state,O_t, R_tplus1,O_tplus1)
            # trace <- grad V_t + lamb * gamma * trace
            # (tree_sum / tree_scalar_multiply are helpers from elsewhere in
            # the project, presumably leaf-wise add and scale.)
            egb_trace=tree_sum(delta_V_t,(tree_scalar_multiply(egb_trace,
                            lamb*self.gamma)))

            # The optimizer minimises, so hand it the negated TD update.
            grads=tree_scalar_multiply(egb_trace,-delta)
            updates, opt_state = self.optimizer.update(grads, opt_state)  # self.optimizer is set up above.
            params = optax.apply_updates(params, updates)
            return delta**2,V_t,opt_state, params, updated_state,egb_trace
        self.update_step=update_step
        self.step_count=0
        self.losses=[]           # squared TD errors since the last log point
        self.predictions=[]      # V_t since the last log point
        self.cumulants=[]        # R_tplus1 since the last log point
        self.predictions_all=[]  # V_t over the whole run
        self.cumulants_all=[]    # R_tplus1 over the whole run


    def step(self, **kwargs):
        """Run one training step.

        Returns:
            ``(loss, metrics)`` where ``loss`` is this step's squared TD error
            and ``metrics`` is ``{'loss': ..., 'msre': ...}`` on every
            ``global_args.log_interval``-th step, otherwise ``None``.
        """
        O_t,O_tplus1,R_tplus1=self.data_gen.step()
        # Advance the PRNG stream so each step sees a fresh key.
        self.random_key=jax.random.split(self.random_key)[0]
        loss,V_t,self.optimizer_state,self.params,self.state,self.egb_trace=self.update_step(O_t, R_tplus1, O_tplus1,self.optimizer_state,self.params,self.state,self.egb_trace,self.random_key)
        self.losses.append(float(loss))
        self.predictions.append(V_t)
        self.predictions_all.append(V_t)
        self.cumulants.append(R_tplus1)
        self.cumulants_all.append(R_tplus1)
        if (self.step_count+1)%self.global_args.log_interval==0:
            loss_mean=np.mean(self.losses)
            predictions,cumulants=np.array(self.predictions),np.array(self.cumulants)
            # Only the first of the three returned values is needed here.
            msre,_,_=compute_return_error(cumulants,predictions,self.gamma)
            metrics={'loss':loss_mean,'msre':msre}
            self.predictions=[]
            self.cumulants=[]
            self.losses=[]
        else:
            metrics=None
        self.step_count+=1
        return loss,metrics

    def get_summary_table(self):
        """Return a JSON string with return-error statistics for the whole run.

        Columns: ``msre``, ``step``, ``returns``, ``return_errors``,
        ``predictions`` and ``cumulants``.
        """
        msre,return_errors,returns=compute_return_error(self.cumulants_all,self.predictions_all,self.gamma)
        result_table=pd.DataFrame({'msre':msre,'step':np.arange(len(returns)),'returns':returns,'return_errors':return_errors,
                        'predictions':self.predictions_all,'cumulants':self.cumulants_all}).to_json(default_handler=str)
        return result_table
